import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
import gym
from external_lib.qrm.src.reward_machines.reward_machine import RewardMachine
from option import Subgoal
import random

class Reacher(mujoco_env.MujocoEnv, utils.EzPickle):
    """MuJoCo reacher arm with four coloured goal bodies.

    The model is loaded from ``reacher.xml`` below ``$LOF_PKG_PATH``. This
    class only drives the physics: ``step`` always returns the fixed
    ``dummy_observation`` with zero reward, and the real state is exposed
    through ``update_all_info`` / ``get_info``.

    Args:
        training: When True, ``reset_model`` randomises the start pose (and,
            if ``rm`` is False, the goal layout). When False the arm starts
            at the fixed pose reaching ``[0.15, 0]``.
        rm: When True the goal layout is always the fixed one.
        dummy_observation: Placeholder observation returned by ``step`` and
            ``reset_model``; a random length-8 vector is drawn when omitted.
    """

    def __init__(self, training=True, rm=True, dummy_observation=None):
        if dummy_observation is None:
            dummy_observation = np.random.rand(8)
        # Assigned before the base-class constructor runs so that `step` and
        # `reset_model` are usable from inside it (presumably
        # MujocoEnv.__init__ probes `step` -- verify against the gym version).
        self.dummy_observation = dummy_observation
        self.training = training
        self.rm = rm
        self.all_info = {}
        utils.EzPickle.__init__(self)

        xml_file = os.path.join(os.environ['LOF_PKG_PATH'], 'env',
                                'mujoco_env', 'assets', 'reacher.xml')

        # Second positional argument is the frame skip.
        mujoco_env.MujocoEnv.__init__(self, xml_file, 2)

    def __reduce__(self):
        """Pickle by re-running the constructor with the original arguments.

        The tuple follows the positional order of ``__init__``; returning
        only ``dummy_observation`` would bind it to ``training`` on unpickle.
        """
        return (self.__class__,
                (self.training, self.rm, self.dummy_observation))

    def step(self, a):
        """Advance the simulation by ``frame_skip`` frames under action ``a``.

        Returns:
            ``(dummy_observation, 0, False, {})`` -- observation and reward
            are placeholders; wrappers compute their own.
        """
        self.do_simulation(a, self.frame_skip)

        return self.dummy_observation, 0, False, {}

    def viewer_setup(self):
        """Point the viewer camera at body id 0."""
        self.viewer.cam.trackbodyid = 0

    def update_all_info(self):
        """Snapshot joint state and body positions into ``self.all_info``.

        Keys: ``jp``/``jv`` (first two joint positions/velocities), ``ee_p``
        (fingertip xy) and ``red_p``/``green_p``/``blue_p``/``yellow_p``
        (goal-body xy).
        """
        self.all_info = {
            'jp': self.sim.data.qpos.flat[:2],
            'jv': self.sim.data.qvel.flat[:2],
            'ee_p': self.get_body_com('fingertip')[:2],
            'red_p': self.get_body_com('red')[:2],
            'green_p': self.get_body_com('green')[:2],
            'blue_p': self.get_body_com('blue')[:2],
            'yellow_p': self.get_body_com('yellow')[:2]
        }

    def get_info(self):
        """Return the dict produced by the last ``update_all_info`` call."""
        return self.all_info

    def get_thetas_for_state(self, state):
        """Inverse kinematics of the planar two-link arm.

        Args:
            state: Sequence whose first two entries are the target ``x, y``.

        Returns:
            ``[q1, q2]`` joint angles. Of the two IK solutions, the one with
            ``q2 = +arccos(...)`` is returned.
        """
        x = state[0]
        y = state[1]
        a1 = 0.1
        a2 = 0.1
        cos_q2 = (x**2 + y**2 - a1**2 - a2**2) / (2 * a1 * a2)
        # Clamp so rounding at the reach limit (or a target slightly out of
        # reach) yields the fully stretched/folded pose instead of NaN.
        cos_q2 = np.clip(cos_q2, -1.0, 1.0)
        q2 = np.arccos(cos_q2)
        q1 = np.arctan2(y, x) - np.arctan2(a2 * np.sin(q2), a1 + a2 * np.cos(q2))
        # The mirrored solution would be q2' = -q2 with
        # q1' = arctan2(y, x) + arctan2(a2*sin(q2'), a1 + a2*cos(q2')).

        return [q1, q2]

    def reset_model(self):
        """Reset the simulation, place the goals and pick a start pose.

        The last eight ``qpos`` entries receive the goal coordinates and the
        first two the arm joint angles; the last eight velocities are zeroed.

        Returns:
            ``self.dummy_observation``, matching what ``step`` returns.
        """
        self.sim.reset()

        qpos = self.np_random.uniform(low=-3, high=3, size=self.model.nq) + self.init_qpos
        # The value of this draw is never used; the call is kept so the
        # global NumPy random stream advances exactly as it did before.
        np.random.randint(4)

        goal1 = np.array([0.1, 0.13])
        goal2 = np.array([-0.1, 0.13])
        goal3 = np.array([0.0, 0.08])
        goal4 = np.array([0.1, -0.1])
        home = [0.15, 0]

        if not self.rm:
            # Layout 0 is the test configuration.
            goal_config = np.random.randint(4) if self.training else 0
            # layout id -> (goal order written to qpos, candidate start points)
            layouts = {
                0: ((goal1, goal4, goal2, goal3), [goal2, goal3, goal4]),
                1: ((goal4, goal1, goal2, goal3), [goal2, goal3, goal1]),
                2: ((goal3, goal4, goal2, goal1), [goal2, goal1, goal4]),
                3: ((goal2, goal3, goal4, goal1), [goal1, goal3, goal4]),
            }
            goal_order, start_candidates = layouts[goal_config]
            self.goal = np.hstack(goal_order)
            # Drawn even when not training so `random` is consumed as before.
            start_pos = random.choice(start_candidates)
            # Evaluation always starts at `home`; training does so 20% of
            # the time (the uniform draw only happens while training).
            if not self.training or np.random.uniform() < 0.2:
                start_pos = home
        else:
            self.goal = np.hstack((goal1, goal4, goal2, goal3))
            if self.training:
                start_pos = random.choice([goal1, goal2, goal3, goal4, np.array(home)])
            else:
                start_pos = home

        qpos[:2] = self.get_thetas_for_state(start_pos)
        # self.goal holds four xy pairs, i.e. exactly the last eight slots.
        qpos[-8:] = self.goal

        qvel = self.init_qvel + self.np_random.uniform(low=-1, high=1, size=self.model.nv)
        qvel[-8:] = 0
        self.set_state(qpos, qvel)

        return self.dummy_observation

        
import numpy as np
import os

# Base configuration used by ReacherEnv; caller-supplied settings are
# layered on top of these values.
default_config = dict(
    # Settings common to all environments
    seed=10,
    synchrounous=True,
    debug=False,
    state_space=None,
    action_space=None,
    get_state=None,
    get_reward=None,
    is_done=None,
    # Reacher-specific: when True the environment is never rendered
    headless=True,
)

class ReacherEnv(object):
    """Thin wrapper around ``Reacher`` that builds states and rewards.

    The wrapped simulator only advances physics; this class turns the info
    dict it exposes into a 10-dimensional state and a distance-based reward
    relative to one of the four coloured targets.

    Args:
        training: Forwarded to ``Reacher``.
        rm: Forwarded to ``Reacher``.
        config: Optional overrides applied on top of ``default_config``.
        seed: Seed handed to the wrapped environment.
        logger: Accepted for interface compatibility; unused here.
        **kwargs: Ignored.
    """

    # Colour code -> key of that target's position in the info dict.
    _TARGET_KEYS = {
        'r': 'red_p',
        'b': 'blue_p',
        'g': 'green_p',
        'y': 'yellow_p',
    }

    def __init__(self, training=True, rm=True, config=None, seed=3, logger=None, **kwargs):
        # Work on a copy: updating `default_config` in place would leak one
        # instance's overrides into every environment created afterwards.
        self.ReacherEnv_config = dict(default_config)
        if config:
            self.ReacherEnv_config.update(config)

        self.render = not self.ReacherEnv_config.get('headless')

        dummy_observation = np.random.rand(self.state_space['shape'][0])

        self.env = Reacher(training=training, rm=rm, dummy_observation=dummy_observation)

        self.set_seed(seed)
        self.all_info = {}

    def update_all_info(self):
        """Refresh the info snapshot from the wrapped environment."""
        self.env.update_all_info()
        self.all_info = self.env.get_info()

    def get_info(self):
        """Return the info dict captured by the last ``update_all_info``."""
        return self.all_info

    def get_state(self, all_info, color='r'):
        """Build the MDP state relative to the target of the given colour.

        Args:
            all_info: Info dict with keys ``jp``, ``jv``, ``ee_p`` and the
                ``*_p`` target positions.
            color: One of ``'r'``, ``'b'``, ``'g'``, ``'y'``.

        Returns:
            Concatenation of ``cos(jp)``, ``sin(jp)``, target position,
            ``jv`` and ``ee_p - target``.

        Raises:
            ValueError: If ``color`` is not a supported code.
        """
        if color not in self._TARGET_KEYS:
            raise ValueError(f"color {color} not supported")
        target = self._TARGET_KEYS[color]

        return np.concatenate([
            np.cos(all_info['jp']),
            np.sin(all_info['jp']),
            all_info[target],
            all_info['jv'],
            all_info['ee_p'] - all_info[target]
        ])

    def get_reward(self, state=None, action=None, next_state=None, all_info=None, color='r'):
        """Return the control cost plus the distance penalty to a target.

        The reward is ``-sum(action**2)`` minus the end-effector/target
        distance; the distance term is dropped once the distance is below
        0.019.

        Note:
            ``state``, ``next_state`` and ``all_info`` are accepted for
            interface compatibility but not read: the positions always come
            from the wrapped environment's current ``all_info``.

        Raises:
            ValueError: If ``color`` is not a supported code.
        """
        all_info = self.env.all_info
        if color not in self._TARGET_KEYS:
            raise ValueError("color not supported")
        target = self._TARGET_KEYS[color]

        r_act = - np.square(action).sum()
        goal_dist = np.linalg.norm(all_info['ee_p'] - all_info[target])
        r_goal = 0 if goal_dist < 0.019 else - goal_dist

        return r_goal + r_act

    def is_done(self, state=None, action=None, next_state=None, all_info=None):
        """Episodes never terminate at this level; always returns False."""
        return False

    def reset(self, **kwargs):
        """Reset the simulator and return the initial state.

        Keyword Args:
            color: Target colour the state is expressed against; defaults
                to ``'r'`` when omitted.
        """
        color = kwargs.get('color', 'r')
        self.env.reset()
        self.update_all_info()

        if self.render:
            self.env.render()

        return self.get_state(self.all_info, color)

    def step(self, actions, **kwargs):
        """Apply ``actions`` to the simulator and refresh the info snapshot.

        Returns nothing; callers read the new state via ``get_info`` /
        ``get_state``.
        """
        if self.render:
            self.env.render()

        # A discrete action space takes a scalar rather than a vector.
        if isinstance(self.env.action_space, gym.spaces.Discrete):
            actions = actions[0]

        self.env.step(actions)
        self.update_all_info()

    def set_seed(self, seed=0, **kwargs):
        """Seed the wrapped environment."""
        self.env.seed(seed)

    def close(self, **kwargs):
        """Close the wrapped environment."""
        self.env.close()

    @property
    def state_space(self):
        """Description of the 10-dimensional float state."""
        return {'type': 'float', 'shape': (10, ), 'upper_bound': [], 'lower_bound': []}

    @property
    def action_space(self):
        """Description of the 2-dimensional float action bounded to [-1, 1]."""
        return {'type': 'float', 'shape': (2, ), "upper_bound": np.ones(2), "lower_bound": -np.ones(2)}

    def restore(self):
        '''
        restores the environment
        '''
        pass

class ReacherGymEnv(gym.Env):
    """Gym interface over ``ReacherEnv`` for reaching one coloured target.

    Args:
        training: Forwarded to ``ReacherEnv``.
        rm: Forwarded to ``ReacherEnv``.
        env_config: Optional settings; ``'horizon'`` (default 50) caps the
            episode length. The dict is copied, so the caller's object is
            never modified.
    """

    def __init__(self, training=True, rm=False, env_config=None):
        # Private copy: `step` writes the default horizon into the config,
        # which must not touch a shared default or the caller's dict.
        env_config = {} if env_config is None else dict(env_config)
        self.env = ReacherEnv(training=training, rm=rm, config=env_config)

        action_dim = self.env.action_space['shape'][0]
        state_dim = self.env.state_space['shape'][0]
        self.action_space = gym.spaces.Box(low=np.array(action_dim*[-0.5]), high=np.array(action_dim*[0.5]), dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=np.array(state_dim*[-200.]), high=np.array(state_dim*[200.]), dtype=np.float32)

        self.env.update_all_info()
        self.all_info = self.env.get_info()

        self.config = env_config
        self.step_num = 0

        self.metadata = self.env.env.metadata
        self.task_done = False

    def get_info(self):
        """Return the wrapped environment's latest info dict."""
        return self.env.get_info()

    def reset(self, color='r', epoch=0):
        """Start a new episode and return the state relative to ``color``.

        ``epoch`` is accepted for interface compatibility and not used.
        """
        self.step_num = 0
        self.task_done = False
        return self.env.reset(color=color)

    def set_task_done(self, done):
        """Let an outside controller end the episode on the next ``step``."""
        self.task_done = done

    def step(self, action, color='r'):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        ``done`` is True once the horizon is reached, ``set_task_done(True)``
        was called, or the wrapped environment reports termination.
        """
        self.step_num += 1

        all_info = self.env.get_info()
        # NOTE(review): observation and reward are computed from the info
        # captured *before* the action is applied, so the returned obs lags
        # the simulator by one step. Kept as-is; confirm this is intended.
        obs = self.env.get_state(all_info, color=color)
        reward = self.env.get_reward(action=action, all_info=all_info, color=color)

        self.env.step(action)
        done = self.env.is_done()
        info = {}

        horizon = self.config.setdefault('horizon', 50)

        task_done = bool(self.step_num >= horizon or self.task_done or done)
        return obs, reward, task_done, info

    @staticmethod
    def _propositions_at(point, all_info, threshold):
        """Return ``(red, green, blue, yellow, empty)`` 0/1 flags for ``point``.

        A colour flag is 1 when ``point`` lies within ``threshold`` of that
        target; ``empty`` is 1 only when no colour flag is set.
        """
        red, green, blue, yellow = (
            int(np.linalg.norm(point - all_info[key]) < threshold)
            for key in ('red_p', 'green_p', 'blue_p', 'yellow_p')
        )
        empty = int(not (red or green or blue or yellow))
        return red, green, blue, yellow, empty

    def get_current_propositions(self, threshold=0.02):
        """Evaluate the propositions at the current end-effector position.

        Refreshes the info snapshot from the simulator first.
        """
        self.env.update_all_info()
        all_info = self.env.get_info()
        return self._propositions_at(all_info['ee_p'], all_info, threshold)

    def get_propositions(self, state, threshold=0.02):
        """Evaluate the propositions at an arbitrary ``[x, y]`` position.

        Uses the target positions from the last info snapshot without
        refreshing it.
        """
        return self._propositions_at(np.array(state), self.env.get_info(), threshold)

    def render(self, mode=None):
        """Render through the underlying MuJoCo environment."""
        return self.env.env.render(mode)


class ReacherGymEnvEval(gym.Env):
    """Evaluation wrapper that tracks progress through a task automaton.

    Wraps ``ReacherGymEnv`` and, after every step, advances an FSA state
    ``f`` using the task specification's transition tensor ``tm`` and the
    currently true proposition. The episode ends when the last FSA state
    (``task_spec.nF - 1``) is reached.

    Args:
        task_spec: Object exposing ``nF`` and ``tm``
            (indexed as ``tm[f, next_f, prop]``).
        cancel_chance: Probability, sampled once at construction, that this
            instance is "cancelled".
        training: Forwarded to ``ReacherGymEnv``.
        env_config: Optional settings forwarded to ``ReacherGymEnv``.
    """

    def __init__(self, task_spec, cancel_chance=0, training=True, env_config=None):
        # A fresh dict per instance: the wrapped env writes into its config,
        # so a shared `{}` default would leak state between instances.
        if env_config is None:
            env_config = {}
        self.env = ReacherGymEnv(training=training, rm=False, env_config=env_config)

        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

        self.all_info = self.env.get_info()

        self.config = env_config
        self.step_num = 0

        self.metadata = self.env.env.env.metadata
        self.task_done = False

        self.subgoals = self.make_subgoals()
        self.task_spec = task_spec
        self.f = 0

        # can_state is the FSA state during which cancellation may occur
        # (currently unused: cancellation applies in every state except 0).
        self.can_state = 1
        # Whether cancellation happens is decided once, using cancel_chance.
        self.cancelled = (np.random.uniform() < cancel_chance)
        print("Cancelled?: {}".format(self.cancelled))

    def make_subgoals(self):
        """Create one ``Subgoal`` per colour from the current target positions."""
        # Subgoal arguments: name, prop_index, subgoal_index, state
        colours = ('red', 'green', 'blue', 'yellow')
        return [
            Subgoal(name, index, index, self.all_info[name + '_p'])
            for index, name in enumerate(colours)
        ]

    def reset(self, color='r'):
        """Restart the episode and put the automaton back in state 0."""
        self.step_num = 0
        self.task_done = False
        self.f = 0
        return self.env.reset(color=color)

    def set_task_done(self, done):
        """Mark the episode as finished here and in the wrapped env."""
        self.env.set_task_done(done)
        self.task_done = done

    def step(self, action, color='r'):
        """Step the wrapped env, advance the FSA and recompute the reward.

        The wrapped reward is discarded: every step before the goal FSA
        state costs ``sum(action**2) + 0.1`` and the goal state yields 0.
        ``info['f']`` carries the new FSA state.
        """
        obs, _, task_done, info = self.env.step(action, color)

        self.f = self.get_fsa_state(self.f)
        goal_state = self.task_spec.nF - 1
        reached_goal = self.f == goal_state

        reward = 0 if reached_goal else - np.square(action).sum() - 0.1

        info['f'] = self.f

        task_done = bool(self.task_done or task_done or reached_goal)
        return obs, reward, task_done, info

    def get_fsa_state(self, f, tm=None, threshold=0.02):
        """Return the FSA state that follows ``f`` under the current props.

        Args:
            f: Current FSA state.
            tm: Transition tensor; defaults to ``task_spec.tm``.
            threshold: Distance below which a target counts as reached.
        """
        if tm is None:
            tm = self.task_spec.tm

        props = np.array(self.get_current_propositions(threshold))
        # augment_props sets exactly one entry, so the first hit is the prop.
        p = np.where(props == 1)[0][0]
        return np.argmax(tm[f, :, p])

    def augment_props(self, props):
        """Expand ``[r, g, b, y, e]`` to the 10-slot cancellable encoding.

        Output layout: ``[r, g, b, y, c, cr, cg, cb, cy, e]``. Exactly one
        slot is set. Cancellation only takes effect once the FSA has left
        state 0. If several colour props are true, the last one wins.
        """
        cancelled = self.cancelled if self.f != 0 else False

        # Index -1 is the trailing `e` slot and is the fallback.
        prop_index = -1
        if props[-1] == 1 and cancelled:
            prop_index = 4

        # Cancelled colour props live five slots after the plain ones.
        offset = 5 if cancelled else 0
        for i, prop in enumerate(props[:-1]):
            if prop == 1:
                prop_index = i + offset

        all_props = [0] * 10
        all_props[prop_index] = 1

        return all_props

    def get_current_propositions(self, threshold=0.02):
        """Return the 10-slot proposition vector at the end effector."""
        props = self.env.get_current_propositions(threshold)
        return self.augment_props(props)

    def get_propositions(self, state, threshold=0.02):
        """Return the 10-slot vector for position ``state``, never cancelled.

        The four colour flags are copied to slots 0-3 and the empty flag to
        the last slot; the cancellation slots stay 0.
        """
        props = self.env.get_propositions(state, threshold)
        all_props = [0] * 10
        all_props[:4] = props[:-1]
        all_props[-1] = props[-1]

        return all_props

    def render(self, mode=None):
        """Render through the underlying MuJoCo environment."""
        return self.env.env.env.render(mode)

class RMReacherGymEnv(gym.Env):
    """Gym environment whose observation is augmented with a reward-machine state.

    The observation is the reward-machine state ``u`` followed by the
    reacher state relative to the red target. The reward machine is loaded
    from ``$LOF_PKG_PATH/baselines/rm/tasks/<task_name>.txt``.

    Args:
        nF: Number of automaton states; ``nF - 1`` is treated as the goal.
        task_name: Base name of the reward-machine task file.
        training: Forwarded to ``ReacherEnv``; also enables the +10 bonus
            on every reward-machine transition.
        env_config: Optional settings; ``'horizon'`` caps the episode
            length. The dict is copied, so the caller's object is never
            modified.
    """

    # Horizon `reset` installs for finite epochs; also the fallback in
    # `step` when no horizon has been configured.
    _DEFAULT_HORIZON = 800

    def __init__(self, nF=0, task_name='', training=True, env_config=None):
        # Private copy: `reset` writes the horizon into the config.
        env_config = {} if env_config is None else dict(env_config)
        # ReacherEnv's parameter is `config`; any other keyword would be
        # swallowed by its **kwargs and the settings silently dropped.
        self.env = ReacherEnv(training=training, rm=True, config=env_config)

        self.nF = nF
        self.task_name = task_name
        self.training = training

        action_dim = self.env.action_space['shape'][0]
        # One extra dimension for the reward-machine state.
        state_dim = self.env.state_space['shape'][0] + 1
        self.action_space = gym.spaces.Box(low=np.array(action_dim*[-0.5]), high=np.array(action_dim*[0.5]), dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=np.array(state_dim*[-200.]), high=np.array(state_dim*[200.]), dtype=np.float32)

        self.env.update_all_info()
        self.all_info = self.env.get_info()

        self.config = env_config
        self.step_num = 0
        self.epoch = None
        # Defined up front so `step` before the first `reset` does not fail
        # on a missing attribute.
        self.u = 0

        self.task_done = False

        self.metadata = self.env.env.metadata

        # initialize RM
        rm_file = os.path.join(os.environ['LOF_PKG_PATH'], 'baselines', 'rm', 'tasks', task_name + '.txt')

        self.rm = RewardMachine(rm_file, use_rs=False, gamma=0.9)

    def reset(self, epoch=np.inf):
        """Start a new episode.

        Args:
            epoch: With the default ``np.inf`` the reward machine starts in
                state 0. Otherwise it starts in ``epoch % (nF - 2)`` and the
                horizon is set to 800.

        Returns:
            ``[u, *mdp_state]`` with the state taken relative to red.
        """
        self.step_num = 0
        self.epoch = epoch
        self.task_done = False
        if epoch == np.inf:
            self.u = 0
        else:
            # Cycle the starting state with the epoch number.
            # NOTE(review): raises ZeroDivisionError when nF == 2.
            self.u = epoch % (self.nF - 2)
            self.config['horizon'] = self._DEFAULT_HORIZON

        mdp_state = self.env.reset(color='r')

        return np.concatenate([np.array([self.u]), mdp_state])

    def step(self, action, color='r', cancel=False):
        """Apply ``action``, advance the reward machine and return the result.

        Args:
            action: Action forwarded to the wrapped environment.
            color: Unused; the state is always relative to red.
            cancel: When True the proposition ``'c'`` is added to the true
                propositions for this step.

        Returns:
            ``(obs, reward, done, info)``. Each step costs
            ``sum(action**2) + 0.1`` unless the next state is ``nF - 1``;
            in training, every state change adds 10.
        """
        self.step_num += 1

        self.env.step(action)
        all_info = self.env.get_info()
        obs_mdp = self.env.get_state(all_info)
        done = self.env.is_done()
        info = {}

        # A proposition holds when the end effector is within 0.02 of the
        # corresponding target.
        targets = (('r', 'red_p'), ('g', 'green_p'), ('b', 'blue_p'), ('y', 'yellow_p'))
        true_props = [
            symbol
            for symbol, key in targets
            if np.linalg.norm(all_info['ee_p'] - all_info[key]) < 0.02
        ]
        if cancel:
            true_props.append('c')

        next_u = self.rm.get_next_state(self.u, true_props)
        changed = next_u != self.u

        if changed:
            print("FSA STATE CHANGED FROM {} TO {}".format(self.u, next_u))

        if next_u != self.nF - 1:
            reward = - np.square(action).sum() - 0.1
        else:
            reward = 0

        if self.training and changed:
            reward += 10

        self.u = next_u
        obs = np.concatenate([np.array([self.u]), obs_mdp])

        # Fall back to the default horizon when `reset` ran with the default
        # epoch and the config did not provide one.
        horizon = self.config.get('horizon', self._DEFAULT_HORIZON)
        if self.step_num >= horizon or self.task_done:
            done = True
        return obs, reward, done, info

    def render(self, mode=None):
        """Render through the underlying MuJoCo environment."""
        return self.env.env.render(mode)

    def set_task_done(self, done):
        """Let an outside controller end the episode on the next ``step``."""
        self.task_done = done

        
def reacher_env_creator(env_config):
    """Build a ``ReacherGymEnv`` from an environment config dict.

    The config is passed by keyword: ``ReacherGymEnv``'s first positional
    parameter is ``training``, so a positional call would bind the dict to
    the wrong argument and leave the real config empty.
    """
    return ReacherGymEnv(env_config=env_config)  # return an env instance

        